{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "g_a9QvUFVCUR"
   },
   "source": [
    "<h1>Chapter 4 - Text Classification</h1>\n",
    "<i>Classifying text with both representation and generative models</i>\n",
    "\n",
    "<a href=\"https://www.amazon.com/Hands-Large-Language-Models-Understanding/dp/1098150961\"><img src=\"https://img.shields.io/badge/Buy%20the%20Book!-grey?logo=amazon\"></a>\n",
    "<a href=\"https://www.oreilly.com/library/view/hands-on-large-language/9781098150952/\"><img src=\"https://img.shields.io/badge/O'Reilly-white.svg?logo=\"></a>\n",
    "<a href=\"https://github.com/HandsOnLLM/Hands-On-Large-Language-Models\"><img src=\"https://img.shields.io/badge/GitHub%20Repository-black?logo=github\"></a>\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/HandsOnLLM/Hands-On-Large-Language-Models/blob/main/chapter04/Chapter%204%20-%20Text%20Classification.ipynb)\n",
    "\n",
    "---\n",
    "\n",
    "This notebook is for Chapter 4 of the [Hands-On Large Language Models](https://www.amazon.com/Hands-Large-Language-Models-Understanding/dp/1098150961) book by [Jay Alammar](https://www.linkedin.com/in/jalammar) and [Maarten Grootendorst](https://www.linkedin.com/in/mgrootendorst/).\n",
    "\n",
    "---\n",
    "\n",
    "<a href=\"https://www.amazon.com/Hands-Large-Language-Models-Understanding/dp/1098150961\">\n",
    "<img src=\"https://raw.githubusercontent.com/HandsOnLLM/Hands-On-Large-Language-Models/main/images/book_cover.png\" width=\"350\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UBeVnXxQWy7-"
   },
   "source": [
    "# **Data**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 9938,
     "status": "ok",
     "timestamp": 1709737297789,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -60
    },
    "id": "5phRS_z2U_3T",
    "outputId": "27f79175-2ec3-4922-e0ba-5bffd56c82cd"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:88: UserWarning: \n",
      "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
      "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
      "You will be able to reuse this secret in all of your notebooks.\n",
      "Please note that authentication is recommended but still optional to access public models or datasets.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 8530\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 1066\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 1066\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# Load our data\n",
    "data = load_dataset(\"rotten_tomatoes\")\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 15,
     "status": "ok",
     "timestamp": 1709737297790,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -60
    },
    "id": "xJJmaJzHDLZv",
    "outputId": "04501032-aed3-425c-8d70-b069a34c280b"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'text': ['the rock is destined to be the 21st century\\'s new \" conan \" and that he\\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .',\n",
       "  'things really get weird , though not particularly scary : the movie is all portent and no content .'],\n",
       " 'label': [1, 0]}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data[\"train\"][0, -1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xya5dfmVoR1R"
   },
   "source": [
    "# **Text Classification with Representation Models**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "co68g-Eloknf"
   },
   "source": [
    "## **Using a Task-specific Model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 17052,
     "status": "ok",
     "timestamp": 1709737314828,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -60
    },
    "id": "ph-3T3XJopdN",
    "outputId": "62abf01e-ba0f-42fc-a8f3-fb467b02583b"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
      "  return self.fget.__get__(instance, owner)()\n",
      "Some weights of the model checkpoint at cardiffnlp/twitter-roberta-base-sentiment-latest were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
      "- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "/usr/local/lib/python3.10/dist-packages/transformers/pipelines/text_classification.py:104: UserWarning: `return_all_scores` is now deprecated,  if want a similar functionality use `top_k=None` instead of `return_all_scores=True` or `top_k=1` instead of `return_all_scores=False`.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import pipeline\n",
    "\n",
    "# Path to our HF model\n",
    "model_path = \"cardiffnlp/twitter-roberta-base-sentiment-latest\"\n",
    "\n",
    "# Load model into pipeline\n",
    "# NOTE: `return_all_scores=True` is deprecated in favour of `top_k=None`, but the\n",
    "# inference loop below reads the scores by position (0 and 2), so verify the order\n",
    "# of the returned scores before switching.\n",
    "# The device falls back to the CPU when CUDA is unavailable, so the notebook still\n",
    "# runs (more slowly) on machines without a GPU.\n",
    "pipe = pipeline(\n",
    "    model=model_path,\n",
    "    tokenizer=model_path,\n",
    "    return_all_scores=True,\n",
    "    device=\"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 37634,
     "status": "ok",
     "timestamp": 1709737352458,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -60
    },
    "id": "B2gbnL5Q69Y5",
    "outputId": "11c80fd6-0609-429f-caa4-443ee7298bbe"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1066/1066 [00:37<00:00, 28.25it/s]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from transformers.pipelines.pt_utils import KeyDataset\n",
    "\n",
    "# Score every test review and keep the stronger of the two sentiment classes\n",
    "test_texts = KeyDataset(data[\"test\"], \"text\")\n",
    "y_pred = []\n",
    "for scores in tqdm(pipe(test_texts), total=len(data[\"test\"])):\n",
    "    # Position 0 is read as the negative score and position 2 as the positive\n",
    "    # score; the entry at position 1 is not used\n",
    "    negative_vs_positive = [scores[0][\"score\"], scores[2][\"score\"]]\n",
    "    y_pred.append(np.argmax(negative_vs_positive))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "X0KyKHtqyjn3"
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "\n",
    "def evaluate_performance(y_true, y_pred):\n",
    "    \"\"\"Print a classification report for binary review-sentiment predictions.\n",
    "\n",
    "    Args:\n",
    "        y_true: Ground-truth labels.\n",
    "        y_pred: Predicted labels, in the same order as `y_true`.\n",
    "    \"\"\"\n",
    "    class_names = [\"Negative Review\", \"Positive Review\"]\n",
    "    report = classification_report(y_true, y_pred, target_names=class_names)\n",
    "    print(report)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 16,
     "status": "ok",
     "timestamp": 1709737352459,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -60
    },
    "id": "fum3MTSyymlW",
    "outputId": "3a04041d-6a9f-44d3-ca15-c458bbb947ac"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                 precision    recall  f1-score   support\n",
      "\n",
      "Negative Review       0.76      0.88      0.81       533\n",
      "Positive Review       0.86      0.72      0.78       533\n",
      "\n",
      "       accuracy                           0.80      1066\n",
      "      macro avg       0.81      0.80      0.80      1066\n",
      "   weighted avg       0.81      0.80      0.80      1066\n",
      "\n"
     ]
    }
   ],
   "source": [
    "evaluate_performance(data[\"test\"][\"label\"], y_pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Wr3WT4jzoNZE"
   },
   "source": [
    "## **Classification Tasks that Leverage Embeddings**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "l8yuSP3heMzT"
   },
   "source": [
    "### Supervised Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 81,
     "referenced_widgets": [
      "2b933a85712a4571b2a01b5838ca3a9f",
      "b45b64fe077247ada77a3cc3955353c2",
      "0289cd2d988748e29548821e78b286cc",
      "0ae0629c02c247e698b13a665690d2af",
      "918cf345f3da443691d1e61ed68790d4",
      "8a627f49d3b1401eafbce793d6ff974b",
      "c56eb48aeb314c2fa8e7ce996c37c90a",
      "9b93708354e643039877c072090b830f",
      "fede446ab0ab46f3b919d305980d70a4",
      "a95f1bb3b57548bb87b5895b997349e5",
      "f5c82f510d134762bbe2f88b732bf9cb",
      "1d7dce3fb68f457c82ce7792809e55b5",
      "42bbdc9bcc464622903e4d2efe803f3b",
      "6910f4cf7f9a4f4aafaa7944fbd85200",
      "42d72d490e484787a6206d008dbc26b8",
      "cf3ed2d1780e4e26b701e8b70038278c",
      "a02435a08f1d4ab2aa67ea174f40b485",
      "89d34ec3f9b14dd1a655500aa451ba6a",
      "fcd513c6053d4a5888a354de040e47d0",
      "3d91c4caf4ec40c5b429a298a9803cbc",
      "42ba1c78e1ae43a1aff0eeaa2ada6a97",
      "c3870de37eb64d1ea358d1887b4445fa"
     ]
    },
    "executionInfo": {
     "elapsed": 26978,
     "status": "ok",
     "timestamp": 1709737379425,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -60
    },
    "id": "jGV9VS4bhq7f",
    "outputId": "47fc54ba-27a7-4043-e5a1-c4b60fa8de93"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2b933a85712a4571b2a01b5838ca3a9f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Batches:   0%|          | 0/267 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1d7dce3fb68f457c82ce7792809e55b5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Batches:   0%|          | 0/34 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "# Load the pretrained sentence-embedding model\n",
    "model = SentenceTransformer(\"sentence-transformers/all-mpnet-base-v2\")\n",
    "\n",
    "# Embed the train split first, then the test split (one progress bar each)\n",
    "train_embeddings, test_embeddings = (\n",
    "    model.encode(data[split][\"text\"], show_progress_bar=True)\n",
    "    for split in (\"train\", \"test\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 18,
     "status": "ok",
     "timestamp": 1709737379426,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -60
    },
    "id": "B5L5CLcOxIeA",
    "outputId": "689aac98-9a8c-4787-a246-ca95cd80cf9f"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(8530, 768)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_embeddings.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 74
    },
    "executionInfo": {
     "elapsed": 16,
     "status": "ok",
     "timestamp": 1709737379426,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -60
    },
    "id": "8A7oTxoph6bn",
    "outputId": "3b88b393-04cd-493b-d396-91148a2cbe92"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-1 {color: black;background-color: white;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 
div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>LogisticRegression(random_state=42)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">LogisticRegression</label><div class=\"sk-toggleable__content\"><pre>LogisticRegression(random_state=42)</pre></div></div></div></div></div>"
      ],
      "text/plain": [
       "LogisticRegression(random_state=42)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "# Fit a logistic regression classifier on the precomputed train embeddings\n",
    "train_labels = data[\"train\"][\"label\"]\n",
    "clf = LogisticRegression(random_state=42).fit(train_embeddings, train_labels)\n",
    "clf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 15,
     "status": "ok",
     "timestamp": 1709737379426,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -60
    },
    "id": "tFvO9KhMokF7",
    "outputId": "6b93052c-0c1a-4e55-dc44-2280a2a66f5c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                 precision    recall  f1-score   support\n",
      "\n",
      "Negative Review       0.85      0.86      0.85       533\n",
      "Positive Review       0.86      0.85      0.85       533\n",
      "\n",
      "       accuracy                           0.85      1066\n",
      "      macro avg       0.85      0.85      0.85      1066\n",
      "   weighted avg       0.85      0.85      0.85      1066\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Predict previously unseen instances\n",
    "y_pred = clf.predict(test_embeddings)\n",
    "evaluate_performance(data[\"test\"][\"label\"], y_pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dwGIHxXpJgrC"
   },
   "source": [
    "**Tip!**  \n",
    "\n",
    "What would happen if we did not use a classifier at all? Instead, we can average the embeddings per class and apply cosine similarity to predict which class matches each document best:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 499,
     "status": "ok",
     "timestamp": 1709737379911,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -60
    },
    "id": "3f_DnG1uJ7Sk",
    "outputId": "077a413d-05a3-429a-c098-2f2c73b4b53f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                 precision    recall  f1-score   support\n",
      "\n",
      "Negative Review       0.85      0.84      0.84       533\n",
      "Positive Review       0.84      0.85      0.84       533\n",
      "\n",
      "       accuracy                           0.84      1066\n",
      "      macro avg       0.84      0.84      0.84      1066\n",
      "   weighted avg       0.84      0.84      0.84      1066\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.metrics import classification_report\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "\n",
    "# Average the embeddings of all documents in each target label.\n",
    "# The labels are appended as one extra column after the embedding dimensions, so\n",
    "# the column to group by is derived from the embedding width rather than\n",
    "# hard-coded to 768 (which only matches models with 768-dimensional output).\n",
    "label_column = train_embeddings.shape[1]\n",
    "df = pd.DataFrame(np.hstack([train_embeddings, np.array(data[\"train\"][\"label\"]).reshape(-1, 1)]))\n",
    "# groupby sorts its keys, so for the 0/1 labels used here row i is the mean embedding of label i\n",
    "averaged_target_embeddings = df.groupby(label_column).mean().values\n",
    "\n",
    "# Find the best matching embeddings between evaluation documents and target embeddings\n",
    "sim_matrix = cosine_similarity(test_embeddings, averaged_target_embeddings)\n",
    "y_pred = np.argmax(sim_matrix, axis=1)\n",
    "\n",
    "# Evaluate the model\n",
    "evaluate_performance(data[\"test\"][\"label\"], y_pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wCWdzjMIjzx0"
   },
   "source": [
    "### Zero-shot Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YSj6CdAetsNp"
   },
   "outputs": [],
   "source": [
    "# Embed one short natural-language description per class (negative first, then positive)\n",
    "label_descriptions = [\"A negative review\", \"A positive review\"]\n",
    "label_embeddings = model.encode(label_descriptions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZEIN7XnbtsQJ"
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "\n",
    "# Compare every test document with both label embeddings and pick the closest label\n",
    "sim_matrix = cosine_similarity(test_embeddings, label_embeddings)\n",
    "y_pred = sim_matrix.argmax(axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 5,
     "status": "ok",
     "timestamp": 1709737379912,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -60
    },
    "id": "u6LyeuEUxIbW",
    "outputId": "7fd0aa01-6de0-449d-c2cb-2534d37fa6b9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                 precision    recall  f1-score   support\n",
      "\n",
      "Negative Review       0.78      0.77      0.78       533\n",
      "Positive Review       0.77      0.79      0.78       533\n",
      "\n",
      "       accuracy                           0.78      1066\n",
      "      macro avg       0.78      0.78      0.78      1066\n",
      "   weighted avg       0.78      0.78      0.78      1066\n",
      "\n"
     ]
    }
   ],
   "source": [
    "evaluate_performance(data[\"test\"][\"label\"], y_pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ox27Rg71zclg"
   },
   "source": [
    "**Tip!**  \n",
    "\n",
    "What would happen if you were to use different descriptions? Use **\"A very negative movie review\"** and **\"A very positive movie review\"** to see what happens!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4CC9iEGcuUit"
   },
   "source": [
    "## **Classification with Generative Models**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qFPPzUHoEESB"
   },
   "source": [
    "### Encoder-decoder Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "nVbTUMktEfJ3"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Load our model\n",
    "# (use the first GPU when CUDA is available, otherwise fall back to the CPU so the\n",
    "# cell does not fail on machines without a GPU)\n",
    "pipe = pipeline(\n",
    "    \"text2text-generation\",\n",
    "    model=\"google/flan-t5-small\",\n",
    "    device=\"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 7,
     "status": "ok",
     "timestamp": 1709737381485,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -60
    },
    "id": "o5nWQORcFlNn",
    "outputId": "3705715d-3286-41b1-e07f-29792826c9f6"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['text', 'label', 't5'],\n",
       "        num_rows: 8530\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['text', 'label', 't5'],\n",
       "        num_rows: 1066\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['text', 'label', 't5'],\n",
       "        num_rows: 1066\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Prefix every review with the instruction and store the result in a new \"t5\" column\n",
    "prompt = \"Is the following sentence positive or negative? \"\n",
    "\n",
    "def add_t5_prompt(example):\n",
    "    \"\"\"Return the `t5` field: the instruction prompt followed by the review text.\"\"\"\n",
    "    return {\"t5\": prompt + example[\"text\"]}\n",
    "\n",
    "data = data.map(add_t5_prompt)\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 41739,
     "status": "ok",
     "timestamp": 1709737423219,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -60
    },
    "id": "Nas574KFFSvR",
    "outputId": "b51b7694-3523-44a7-fcaa-48b2cfa2188a"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r\n",
      "  0%|          | 0/1066 [00:00<?, ?it/s]/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1178: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.\n",
      "  warnings.warn(\n",
      "100%|██████████| 1066/1066 [00:40<00:00, 26.07it/s]\n"
     ]
    }
   ],
   "source": [
    "# Run inference\n",
    "y_pred = []\n",
    "for output in tqdm(pipe(KeyDataset(data[\"test\"], \"t5\")), total=len(data[\"test\"])):\n",
    "    # Normalise case and surrounding whitespace before comparing, so a reply such as\n",
    "    # \"Negative\" or \"negative \" is not silently counted as a positive prediction\n",
    "    text = output[0][\"generated_text\"].strip().lower()\n",
    "    # Any reply other than \"negative\" is treated as positive (label 1)\n",
    "    y_pred.append(0 if text == \"negative\" else 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 22,
     "status": "ok",
     "timestamp": 1709737423219,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -60
    },
    "id": "1Wk2i856GnCv",
    "outputId": "d1088769-6af2-46a2-b7fc-49769d0276aa"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                 precision    recall  f1-score   support\n",
      "\n",
      "Negative Review       0.83      0.85      0.84       533\n",
      "Positive Review       0.85      0.83      0.84       533\n",
      "\n",
      "       accuracy                           0.84      1066\n",
      "      macro avg       0.84      0.84      0.84      1066\n",
      "   weighted avg       0.84      0.84      0.84      1066\n",
      "\n"
     ]
    }
   ],
   "source": [
    "evaluate_performance(data[\"test\"][\"label\"], y_pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "p4V9iq_EELWx"
   },
   "source": [
    "### ChatGPT for Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tES6HFOwNjF6"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import openai\n",
    "\n",
    "# Create client\n",
    "# Prefer the OPENAI_API_KEY environment variable so the key does not have to be\n",
    "# saved in the notebook; the placeholder is only used when the variable is unset\n",
    "client = openai.OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\", \"YOUR_KEY_HERE\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dGiovm3wyCOz"
   },
   "outputs": [],
   "source": [
    "def chatgpt_generation(prompt, document, model=\"gpt-3.5-turbo-0125\"):\n",
    "    \"\"\"Fill the prompt template with a document and return the chat model's reply.\n",
    "\n",
    "    Args:\n",
    "        prompt: Template text; every \"[DOCUMENT]\" placeholder is replaced by `document`.\n",
    "        document: Text inserted into the template.\n",
    "        model: Name of the chat model to query.\n",
    "\n",
    "    Returns:\n",
    "        The message content of the first choice in the completion.\n",
    "    \"\"\"\n",
    "    user_content = prompt.replace(\"[DOCUMENT]\", document)\n",
    "    conversation = [\n",
    "        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
    "        {\"role\": \"user\", \"content\": user_content},\n",
    "    ]\n",
    "    # temperature=0 is passed to keep sampling randomness to a minimum\n",
    "    response = client.chat.completions.create(\n",
    "        messages=conversation,\n",
    "        model=model,\n",
    "        temperature=0,\n",
    "    )\n",
    "    return response.choices[0].message.content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "executionInfo": {
     "elapsed": 647,
     "status": "ok",
     "timestamp": 1709739851329,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -60
    },
    "id": "qL_kMwQEvMcd",
    "outputId": "837cf556-4dab-482f-ad66-9d6fb06bcda3"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.google.colaboratory.intrinsic+json": {
       "type": "string"
      },
      "text/plain": [
       "'1'"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Define a prompt template as a base\n",
    "prompt = \"\"\"Predict whether the following document is a positive or negative movie review:\n",
    "\n",
    "[DOCUMENT]\n",
    "\n",
    "If it is positive return 1 and if it is negative return 0. Do not give any other answers.\n",
    "\"\"\"\n",
    "\n",
    "# Predict the target using GPT\n",
    "document = \"unpretentious , charming , quirky , original\"\n",
    "chatgpt_generation(prompt, document)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ea-8XYpY3jp6"
   },
   "source": [
    "The next step would be to run one of OpenAI's model against the entire evaluation dataset. However, only run this when you have sufficient tokens as this will call the API for the entire test dataset (1066 records)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 814003,
     "status": "ok",
     "timestamp": 1709738237754,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -60
    },
    "id": "gEyGElIv25Aq",
    "outputId": "20b983d7-c154-4bfd-d652-2ed42f425c54"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1066/1066 [13:34<00:00,  1.31it/s]\n"
     ]
    }
   ],
   "source": [
    "# You can skip this if you want to save your (free) credits\n",
    "predictions = [chatgpt_generation(prompt, doc) for doc in tqdm(data[\"test\"][\"text\"])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 13,
     "status": "ok",
     "timestamp": 1709738237755,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -60
    },
    "id": "B07O8wtZ05x1",
    "outputId": "48ab3146-85e7-4d92-e550-04bd98e39cb9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                 precision    recall  f1-score   support\n",
      "\n",
      "Negative Review       0.87      0.97      0.92       533\n",
      "Positive Review       0.96      0.86      0.91       533\n",
      "\n",
      "       accuracy                           0.91      1066\n",
      "      macro avg       0.92      0.91      0.91      1066\n",
      "   weighted avg       0.92      0.91      0.91      1066\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Extract predictions\n",
    "y_pred = [int(pred) for pred in predictions]\n",
    "\n",
    "# Evaluate performance\n",
    "evaluate_performance(data[\"test\"][\"label\"], y_pred)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyOVGHefh7zey8Fc1ZhM0hvU",
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "0289cd2d988748e29548821e78b286cc": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_9b93708354e643039877c072090b830f",
      "max": 267,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_fede446ab0ab46f3b919d305980d70a4",
      "value": 267
     }
    },
    "0ae0629c02c247e698b13a665690d2af": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_a95f1bb3b57548bb87b5895b997349e5",
      "placeholder": "​",
      "style": "IPY_MODEL_f5c82f510d134762bbe2f88b732bf9cb",
      "value": " 267/267 [00:19&lt;00:00, 30.88it/s]"
     }
    },
    "1d7dce3fb68f457c82ce7792809e55b5": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_42bbdc9bcc464622903e4d2efe803f3b",
       "IPY_MODEL_6910f4cf7f9a4f4aafaa7944fbd85200",
       "IPY_MODEL_42d72d490e484787a6206d008dbc26b8"
      ],
      "layout": "IPY_MODEL_cf3ed2d1780e4e26b701e8b70038278c"
     }
    },
    "2b933a85712a4571b2a01b5838ca3a9f": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_b45b64fe077247ada77a3cc3955353c2",
       "IPY_MODEL_0289cd2d988748e29548821e78b286cc",
       "IPY_MODEL_0ae0629c02c247e698b13a665690d2af"
      ],
      "layout": "IPY_MODEL_918cf345f3da443691d1e61ed68790d4"
     }
    },
    "3d91c4caf4ec40c5b429a298a9803cbc": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "42ba1c78e1ae43a1aff0eeaa2ada6a97": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "42bbdc9bcc464622903e4d2efe803f3b": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_a02435a08f1d4ab2aa67ea174f40b485",
      "placeholder": "​",
      "style": "IPY_MODEL_89d34ec3f9b14dd1a655500aa451ba6a",
      "value": "Batches: 100%"
     }
    },
    "42d72d490e484787a6206d008dbc26b8": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_42ba1c78e1ae43a1aff0eeaa2ada6a97",
      "placeholder": "​",
      "style": "IPY_MODEL_c3870de37eb64d1ea358d1887b4445fa",
      "value": " 34/34 [00:02&lt;00:00, 19.19it/s]"
     }
    },
    "6910f4cf7f9a4f4aafaa7944fbd85200": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_fcd513c6053d4a5888a354de040e47d0",
      "max": 34,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_3d91c4caf4ec40c5b429a298a9803cbc",
      "value": 34
     }
    },
    "89d34ec3f9b14dd1a655500aa451ba6a": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "8a627f49d3b1401eafbce793d6ff974b": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "918cf345f3da443691d1e61ed68790d4": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "9b93708354e643039877c072090b830f": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "a02435a08f1d4ab2aa67ea174f40b485": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "a95f1bb3b57548bb87b5895b997349e5": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "b45b64fe077247ada77a3cc3955353c2": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_8a627f49d3b1401eafbce793d6ff974b",
      "placeholder": "​",
      "style": "IPY_MODEL_c56eb48aeb314c2fa8e7ce996c37c90a",
      "value": "Batches: 100%"
     }
    },
    "c3870de37eb64d1ea358d1887b4445fa": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "c56eb48aeb314c2fa8e7ce996c37c90a": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "cf3ed2d1780e4e26b701e8b70038278c": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "f5c82f510d134762bbe2f88b732bf9cb": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "fcd513c6053d4a5888a354de040e47d0": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "fede446ab0ab46f3b919d305980d70a4": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
